{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jsonlines\n",
    "import json\n",
    "from tsllm.envs import get_env_answer_checker\n",
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deduplication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5199.825371336812\n",
      "89676 -> 78732 after deduplicate.\n",
      "correct: 0.732091144642585\n"
     ]
    }
   ],
   "source": [
    "input_file_path = \"train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12.jsonl\"\n",
    "output_file_path = \"train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-dedup.jsonl\"\n",
    "\n",
    "check_fn = get_env_answer_checker(\"gsm8k\")\n",
    "\n",
    "# Statistics accumulated over the whole input file.\n",
    "total_tokens = 0\n",
    "cnt = 0\n",
    "cnt_before_dedup, cnt_after_dedup = 0, 0\n",
    "correct_cnt = 0\n",
    "dedup_objs = []\n",
    "\n",
    "with jsonlines.open(input_file_path, \"r\") as reader:\n",
    "    for obj in reader:\n",
    "        total_tokens += obj[\"result\"][\"#token\"]\n",
    "        cnt += 1\n",
    "\n",
    "        # Keep the first rollout seen for each distinct text. Dicts preserve\n",
    "        # insertion order, so the survivors stay in their original order.\n",
    "        first_by_text = {}\n",
    "        for rollout in obj[\"output\"]:\n",
    "            cnt_before_dedup += 1\n",
    "            text = rollout[\"text\"]\n",
    "            if text in first_by_text:\n",
    "                continue\n",
    "            # Only surviving rollouts are passed to the checker; its result is\n",
    "            # stored on the rollout as the `correct` field.\n",
    "            rollout[\"correct\"] = check_fn(obj[\"question\"], obj[\"groundtruth\"], text)\n",
    "            correct_cnt += 1 if rollout[\"correct\"] else 0\n",
    "            first_by_text[text] = rollout\n",
    "        cnt_after_dedup += len(first_by_text)\n",
    "\n",
    "        # Replace the raw `output` list with the deduplicated `answer` list.\n",
    "        del obj[\"output\"]\n",
    "        obj[\"answer\"] = list(first_by_text.values())\n",
    "        dedup_objs.append(obj)\n",
    "\n",
    "# Mean of result[\"#token\"] per record, dedup counts, and the fraction of\n",
    "# deduplicated rollouts flagged correct.\n",
    "print(total_tokens / cnt)\n",
    "print(\"{} -> {} after deduplicate.\".format(cnt_before_dedup, cnt_after_dedup))\n",
    "print(\"correct: {}\".format(correct_cnt / cnt_after_dedup))\n",
    "\n",
    "\n",
    "with jsonlines.open(output_file_path, \"w\") as writer:\n",
    "    writer.write_all(dedup_objs)\n",
    "\n",
    "# Drop the path names so later cells cannot silently reuse them.\n",
    "del input_file_path, output_file_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merge previous rollout data and current rollout data for policy training\n",
    "\n",
    "Set `input_file_path_0` to the path of the previous rollout data and `input_file_path_1` to the path of the current rollout data.\n",
    "\n",
    "Set `output_file_path` to the path where the merged data should be written."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ADD 6504 new instances\n",
      "TOTAL DATA: 64134\n"
     ]
    }
   ],
   "source": [
    "input_file_path_1 = \"train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-dedup.jsonl\"\n",
    "# Path is relative to the repository root, under the `tsllm` package directory.\n",
    "input_file_path_0 = \"tsllm/envs/gsm8k/train_data/sft_init.jsonl\"\n",
    "output_file_path = \"train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-merge_train-dedup.jsonl\"\n",
    "\n",
    "# Index the current (deduplicated) rollout records by their `i` field.\n",
    "obj_dict = {}\n",
    "cnt = 0        # answers taken over from the previous data\n",
    "total_cnt = 0  # answers flagged `correct` after merging\n",
    "with jsonlines.open(input_file_path_1, \"r\") as reader:\n",
    "    for obj in reader:\n",
    "        obj_dict[obj[\"i\"]] = obj\n",
    "\n",
    "with jsonlines.open(input_file_path_0, \"r\") as reader:\n",
    "    # The line position in the previous file is used as the lookup key, so both\n",
    "    # files must describe the same questions; the assert below checks this.\n",
    "    for i, obj in enumerate(reader):\n",
    "        obj_to_merge = obj_dict[i]\n",
    "        assert obj_to_merge[\"question\"] == obj[\"question\"]\n",
    "        current_texts = {o[\"text\"] for o in obj_to_merge[\"answer\"]}\n",
    "\n",
    "        # Only the first answer of each previous record is considered.\n",
    "        come_in_output = obj[\"answer\"][0]\n",
    "        if come_in_output[\"text\"] not in current_texts:\n",
    "            obj_to_merge[\"answer\"].append(come_in_output)\n",
    "            cnt += 1\n",
    "        total_cnt += sum(1 for x in obj_to_merge[\"answer\"] if x[\"correct\"])\n",
    "\n",
    "print(\"ADD {} new instances\".format(cnt))\n",
    "print(\"TOTAL DATA: {}\".format(total_cnt))\n",
    "\n",
    "# SL sft training data\n",
    "with jsonlines.open(output_file_path, \"w\") as writer:\n",
    "    for obj in obj_dict.values():\n",
    "        writer.write(obj)\n",
    "\n",
    "# Delete the names this cell actually defines. `input_file_path` does not exist\n",
    "# here, so deleting it raised NameError on a fresh kernel.\n",
    "del input_file_path_0, input_file_path_1, output_file_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Merge previous rollout data and current rollout data for critic training\n",
    "\n",
    "Set `input_file_path_0` to the path of the previous rollout data and `input_file_path_1` to the path of the current rollout data.\n",
    "\n",
    "Set `output_file_path` to the path where the merged data should be written."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ADD 280037/345945 NEW DATA, NOW 358769\n"
     ]
    }
   ],
   "source": [
    "input_file_path_0 = \"tsllm/offline_rl/gsm8k_data/processed/gsm8k_train_cot_sample_sft_k100_merged_dedup_sample17x3.jsonl\"\n",
    "input_file_path_1 = \"train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-dedup.jsonl\"\n",
    "output_file_path = \"train_mcts_scripts/gsm8k/iteration1/mcts_rollout-k12-value-sl_train-dedup.jsonl\"\n",
    "\n",
    "# Seed both RNGs so the subsampling below is repeatable.\n",
    "seed = 1\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "# Current rollout records, keyed by their `i` field.\n",
    "with jsonlines.open(input_file_path_1, \"r\") as reader:\n",
    "    obj_dict = {record[\"i\"]: record for record in reader}\n",
    "\n",
    "cnt = 0               # previous answers actually added\n",
    "total_cnt = 0         # previous answers seen before subsampling\n",
    "merged_dedup_cnt = 0  # answers per record after merging, summed\n",
    "K = 51 - 12           # cap on previous answers drawn per record\n",
    "subsample_obj_dict = {}\n",
    "with jsonlines.open(input_file_path_0, \"r\") as reader:\n",
    "    # The line position in the previous file is used as the key into obj_dict.\n",
    "    for i, obj in enumerate(reader):\n",
    "        merged = obj_dict[i]\n",
    "        assert obj[\"question\"] == merged[\"question\"]\n",
    "        previous_answers = obj[\"answer\"]\n",
    "        current_texts = {o[\"text\"] for o in merged[\"answer\"]}\n",
    "        total_cnt += len(previous_answers)\n",
    "        print(len(previous_answers), end=\"\\r\")\n",
    "\n",
    "        # Draw K answers without replacement only when more than K exist.\n",
    "        subsample_list = (\n",
    "            np.random.choice(previous_answers, K, replace=False)\n",
    "            if len(previous_answers) > K\n",
    "            else previous_answers\n",
    "        )\n",
    "\n",
    "        for o in subsample_list:\n",
    "            if o[\"text\"] in current_texts:\n",
    "                continue\n",
    "            current_texts.add(o[\"text\"])\n",
    "            merged[\"answer\"].append(o)\n",
    "            cnt += 1\n",
    "        merged_dedup_cnt += len(merged[\"answer\"])\n",
    "print(\"ADD {}/{} NEW DATA, NOW {}\".format(cnt, total_cnt, merged_dedup_cnt))\n",
    "\n",
    "# SL sft training data\n",
    "with jsonlines.open(output_file_path, \"w\") as writer:\n",
    "    writer.write_all(obj_dict.values())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mcts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
